/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file batch_data_copy_wrapper.h
 * \brief
 */

#ifndef IMPL_MATMUL_MODULES_STAGE_BATCH_COPY_CUBE_IN_BATCH_DATA_COPY_WRAPPER_H
#define IMPL_MATMUL_MODULES_STAGE_BATCH_COPY_CUBE_IN_BATCH_DATA_COPY_WRAPPER_H

#include "../../../matmul_module.h"
#include "../../../matmul_param.h"
#include "../copy_cube_in_utils.h"
#include "../copy_cube_in_params.h"

namespace AscendC {
namespace Impl {
namespace Detail {

template <typename IMPL, const auto& MM_CFG, class INPUT_TYPE>
class BatchDataCopyWrapper
{
    using TransT = typename INPUT_TYPE::TRANS_T;
    using SrcT = typename INPUT_TYPE::T;

    MATMUL_USE_MODULE_ON(CopyCubeInParams, INPUT_TYPE::TAG);
    MATMUL_USE_MODULE_ON(MatmulTensorInfo, INPUT_TYPE::TAG);
    MATMUL_USE_MODULE(MatmulShapeTiling);
    MATMUL_USE_MODULE(MatmulShapeInfo);
    MATMUL_USE_MODULE(MatmulUserDefineInfo);
    MATMUL_USE_MODULE(LocalWorkspace);

    // Height (in elements) of the statically-sized A-matrix tile held in L1:
    //  - batched layouts kept whole in L1 -> full single-core shape (K when
    //    transposed, otherwise M);
    //  - (special) MDL -> step count times base block;
    //  - otherwise -> the base block reported by CopyCubeInParams.
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ constexpr enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::A, int32_t> GetStaticTileHeight() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? tiling.GetSingleCoreK() : tiling.GetSingleCoreM();
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? (tiling.GetStepKa() * tiling.GetBaseK()) : (tiling.GetStepM() * tiling.GetBaseM());
        } else {
            return MATMUL_MODULE(CopyCubeInParams)->template GetBaseHeight<IS_TRANS>();
        }
    }

    // Width (in elements) of the statically-sized A-matrix tile held in L1:
    //  - batched layouts kept whole in L1 -> full single-core shape (M when
    //    transposed, otherwise K);
    //  - (special) MDL -> step count times base block;
    //  - otherwise -> the base block reported by CopyCubeInParams.
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ constexpr enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::A, int32_t> GetStaticTileWidth() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? tiling.GetSingleCoreM() : tiling.GetSingleCoreK();
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? (tiling.GetStepM() * tiling.GetBaseM()) : (tiling.GetStepKa() * tiling.GetBaseK());
        } else {
            return MATMUL_MODULE(CopyCubeInParams)->template GetBaseWidth<IS_TRANS>();
        }
    }

    // Height (in elements) of the statically-sized B-matrix tile held in L1:
    //  - batched layouts kept whole in L1 -> full single-core shape (N when
    //    transposed, otherwise K);
    //  - (special) MDL -> step count times base block;
    //  - otherwise -> the base block reported by CopyCubeInParams.
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ inline enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::B, int32_t> GetStaticTileHeight() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? tiling.GetSingleCoreN() : tiling.GetSingleCoreK();
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? (tiling.GetStepN() * tiling.GetBaseN()) : (tiling.GetStepKb() * tiling.GetBaseK());
        } else {
            return MATMUL_MODULE(CopyCubeInParams)->template GetBaseHeight<IS_TRANS>();
        }
    }

    // Width (in elements) of the statically-sized B-matrix tile held in L1:
    //  - batched layouts kept whole in L1 -> full single-core shape (K when
    //    transposed, otherwise N);
    //  - (special) MDL -> step count times base block;
    //  - otherwise -> the base block reported by CopyCubeInParams.
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ inline enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::B, int32_t> GetStaticTileWidth() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? tiling.GetSingleCoreK() : tiling.GetSingleCoreN();
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            const auto& tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
            return IS_TRANS ? (tiling.GetStepKb() * tiling.GetBaseK()) : (tiling.GetStepN() * tiling.GetBaseN());
        } else {
            return MATMUL_MODULE(CopyCubeInParams)->template GetBaseWidth<IS_TRANS>();
        }
    }

public:
    // Trivial construction/destruction: the wrapper owns no resources; its only
    // data members are plain integers and a compile-time constant.
    __aicore__ inline BatchDataCopyWrapper() = default;
    __aicore__ inline ~BatchDataCopyWrapper() = default;

    // Copies ndNum ND tiles from GM into L1 in NZ format. If a user copy callback
    // (CopyA1Ptr for matrix A / CopyB1Ptr for matrix B) is registered, the callback
    // performs the transfer instead of the built-in CopyND2NZ.
    // NOTE(review): the built-in fallback is compiled only for __CCE_AICORE__ >= 220;
    // on older cores this else-branch is empty — confirm callers route pre-220
    // traffic through another path (e.g. CopyND2NZOnTheFly).
    // The CPU-debug build must test the callback pointers at runtime, hence the
    // duplicated #ifdef branches that deliberately share the trailing `else` below.
    __aicore__ inline void BatchCopyND2NZ(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src, const int row,
                                          const int col, const int height, const int width, const int gCol,
                                          const int ndNum = 1, const int srcNdMatrixStride = 0,
                                          const int dstNzMatrixStride = 0, const bool kAlignToC0Size = false)
    {
#ifdef ASCENDC_CPU_DEBUG
        if (INPUT_TYPE::TAG == InputTypeTag::A && IMPL::CallBack::CopyA1Ptr) {
            // Callbacks take a byte-typed view of the destination tensor.
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyA1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
        } else if (INPUT_TYPE::TAG == InputTypeTag::B && IMPL::CallBack::CopyB1Ptr) {
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyB1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
#else
        if constexpr (INPUT_TYPE::TAG == InputTypeTag::A && IMPL::CallBack::CopyA1Ptr) {
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyA1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
        } else if constexpr (INPUT_TYPE::TAG == InputTypeTag::B && IMPL::CallBack::CopyB1Ptr) {
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyB1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
#endif
        } else {
#if __CCE_AICORE__ >= 220
            CopyND2NZ(dst, src, row, col, height, width, gCol, ndNum, srcNdMatrixStride, dstNzMatrixStride,
                      kAlignToC0Size);
#endif
        }
    }

    // NZ -> NZ copy between local buffers. Delegates to the local-memory CopyNZ2NZ
    // overload; kAlignToC0Size is accepted only for interface symmetry with the
    // GM-source overload and has no effect on local-to-local copies.
    __aicore__ inline void BatchCopyNZ2NZ(const LocalTensor<TransT>& dst, const LocalTensor<TransT>& src, int row,
                                          int col, int height, int width, int gRow, bool kAlignToC0Size = false)
    {
        (void)kAlignToC0Size; // unused on this path
        CopyNZ2NZ(dst, src, row, col, height, width, gRow);
    }

    // Copies an NZ tile from GM into L1. If a user copy callback (CopyA1Ptr /
    // CopyB1Ptr) is registered for this matrix, the callback performs the transfer
    // instead of the built-in CopyNZ2NZ.
    // The CPU-debug build must test the callback pointers at runtime, hence the
    // duplicated #ifdef branches that deliberately share the trailing `else` below.
    __aicore__ inline void BatchCopyNZ2NZ(const LocalTensor<TransT>& dst, const GlobalTensor<TransT>& src,
                                          const int row, const int col, const int height, const int width,
                                          const int gRow, const bool kAlignToC0Size = false)
    {
#ifdef ASCENDC_CPU_DEBUG
        if (INPUT_TYPE::TAG == InputTypeTag::A && IMPL::CallBack::CopyA1Ptr) {
            // Callbacks take a byte-typed view of the destination tensor.
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyA1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
        } else if (INPUT_TYPE::TAG == InputTypeTag::B && IMPL::CallBack::CopyB1Ptr) {
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyB1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
#else
        if constexpr (INPUT_TYPE::TAG == InputTypeTag::A && IMPL::CallBack::CopyA1Ptr) {
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyA1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
        } else if constexpr (INPUT_TYPE::TAG == InputTypeTag::B && IMPL::CallBack::CopyB1Ptr) {
            LocalTensor<int8_t> a1Tmp = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyB1Ptr)(a1Tmp, reinterpret_cast<__gm__ void *>(src.address_), row, col, height, width,
                                        MATMUL_MODULE(MatmulUserDefineInfo)->GetUserDefineInfo(), MATMUL_MODULE(MatmulUserDefineInfo)->GetSelfDefineData());
#endif
        } else {
            CopyNZ2NZ(dst, src, row, col, height, width, gRow, kAlignToC0Size);
        }
    }

#if __CCE_AICORE__ < 220
    // Copies an ND tile from GM into L1 in NZ layout on pre-220 cores (no hardware
    // ND2NZ path). Whole C0-aligned columns are moved with DataCopy; a ragged tail
    // column is staged through a UB workspace, zero-padded with a masked Duplicate,
    // then copied to L1.
    //   row/col : tile origin inside the source matrix (element units)
    //   height  : rows to copy; width : columns to copy
    //   gCol    : source matrix row pitch in elements (>= width)
    __aicore__ inline void CopyND2NZOnTheFly(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src, int row,
                                             int col, int height, int width, int gCol)
    {
        ASSERT(gCol >= width && "Copy ND block gm->ub width larger than origin matrix width.");
        int calcWidth = width / c0Size_; // cube block numbers that do not need to be pad zero
        int dstOffset = 0;
        // 64-bit offset: row * gCol can overflow 32 bits for large matrices.
        int64_t srcOffset = (static_cast<int64_t>(row) * static_cast<int64_t>(gCol) + static_cast<int64_t>(col));
        int calcWidthExr = Ceil(width, c0Size_);      // C0 blocks per row, tail included
        int calcHeightExr = Ceil(height, BLOCK_CUBE); // 16-row fractals per column

#if __CCE_AICORE__ == 200
        // set2d, pad tail zero
        if (height % BLOCK_CUBE != 0) {
            // Pre-clear the destination so the ragged bottom rows read as zeros.
            int64_t repeat = calcWidthExr * calcHeightExr;
            if constexpr (IsSameType<SrcT, int8_t>::value) {
                // int8 data is cleared through an int16 view (zero has the same bytes).
                LocalTensor<int16_t> tmp = dst.template ReinterpretCast<int16_t>();
                InitConstValueParams<int16_t> initConstValueParams;
                initConstValueParams.repeatTimes = static_cast<int16_t>(repeat);
                initConstValueParams.initValue = 0;
                InitConstValue(tmp, initConstValueParams);
            } else {
                InitConstValueParams<SrcT> initConstValueParams;
                initConstValueParams.repeatTimes = static_cast<int16_t>(repeat);
                initConstValueParams.initValue = 0;
                InitConstValue(dst, initConstValueParams);
            }
            PipeBarrier<PIPE_MTE2>();
        }
#endif

        // gCol unaligned, can not use dma copy repeat stride
        int tail = width % c0Size_;
        if (tail) {
            // tail elements that need to be pad zero
            // NOTE(review): blockLen is computed but never used below — confirm dead code.
            int blockLen = calcWidthExr * (c0Size_ * sizeof(SrcT) / DEFAULT_C0_SIZE);

            // gm->l1
            int src_gap = gCol * sizeof(SrcT) / ONE_BLK_SIZE - 1;
            if (gCol % c0Size_ || src_gap >= UINT16_MAX) {
                // each block len is only 32B
                for (auto i = 0; i < calcWidth; i++) {
                    for (auto j = 0; j < height; j++) {
                        DataCopy(dst[dstOffset + i * calcHeightExr * BLOCK_CUBE * c0Size_ + j * c0Size_],
                                 src[srcOffset + j * gCol + i * c0Size_], { 1, 1, 0, 0 });
                    }
                }
            } else {
                // data copy stride is aligned
                for (auto i = 0; i < calcWidth; i++) {
                    DataCopy(dst[dstOffset], src[srcOffset],
                             { static_cast<uint16_t>(height), 1, static_cast<uint16_t>(src_gap), 0 });
                    dstOffset += calcHeightExr * BLOCK_CUBE * c0Size_;
                    srcOffset += c0Size_;
                }
            }

            // tail gm->ub pad zero, and then ub->l1
            int32_t tileHeight;
            if (IsTranspose()) {
                tileHeight = GetStaticTileHeight<true>();
            } else {
                tileHeight = GetStaticTileHeight<false>();
            }
            auto size = tileHeight * ONE_BLK_SIZE / sizeof(SrcT);

            LocalTensor<SrcT> trans;
            trans = MATMUL_MODULE(LocalWorkspace)->GetND2NZWorkspace(0).template ReinterpretCast<SrcT>();
            trans.SetSize(size);

            int64_t tailSrcoffset = (int64_t)row * (int64_t)gCol + (int64_t)col + (int64_t)calcWidth * (int64_t)c0Size_;

            // gm->ub
            // One 32B block per row: each row's tail lands at a c0Size_ stride in UB.
            for (auto i = 0; i < height; ++i) {
                DataCopy(trans[i * c0Size_], src[tailSrcoffset], { 1, 1, 0, 0 });
                tailSrcoffset += gCol;
            }

            event_t eventIDMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIDMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIDMte2ToV);

            // tail pad zero
            // Build a mask whose set bits select only the lanes past the valid tail
            // within each 16-lane group.
            uint64_t mask[2];
            if constexpr (IsSameType<SrcT, int8_t>::value) {
                tail = Ceil(tail, 2); // int8 lanes are processed as int16 pairs
            }
            uint16_t mask_tail = ~((1 << tail) - 1);
            uint64_t masktail = mask_tail;
            mask[0] = masktail + (masktail << 16) + (masktail << 32) + (masktail << 48);
            mask[1] = mask[0];
            if (masktail != 0) {
                if constexpr (IsSameType<SrcT, int8_t>::value) {
                    LocalTensor<int16_t> tmpTrans = trans.template ReinterpretCast<int16_t>();
                    Duplicate(tmpTrans, static_cast<int16_t>(0), mask, Ceil(height, 8), 1, 8);
                } else {
                    Duplicate(trans, static_cast<SrcT>(0), mask, Ceil(height, 8), 1, 8);
                }
            }

            event_t eventIDVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
            SetFlag<HardEvent::V_MTE3>(eventIDVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIDVToMte3);

            // ub->l1
            int heightAlignBlock = Ceil(height, BLOCK_CUBE);
            int tailDstOffset = heightAlignBlock * BLOCK_CUBE * c0Size_ * calcWidth;
            DataCopy(dst[tailDstOffset], trans, { static_cast<uint16_t>(height), 1, 0, 0 });
        } else {
            // Fully C0-aligned width: pure DMA, no UB staging needed.
            int src_gap = gCol * sizeof(SrcT) / ONE_BLK_SIZE - 1;
            if (gCol % c0Size_ != 0 || src_gap >= UINT16_MAX) {
                int64_t oriSrcOffset = srcOffset;
                int oriDstOffset = dstOffset;
                // each block len is only 32B
                for (int i = 0; i < calcWidth; i++) {
                    for (int j = 0; j < height; j++) {
                        DataCopy(dst[dstOffset], src[srcOffset], { 1, 1, 0, 0 });
                        dstOffset += c0Size_;
                        srcOffset += gCol;
                    }
                    srcOffset = oriSrcOffset + (i + 1) * c0Size_;
                    dstOffset = oriDstOffset + (i + 1) * calcHeightExr * BLOCK_CUBE * c0Size_;
                }
            } else {
                // data copy stride is aligned
                if constexpr (INPUT_TYPE::layout == LayoutMode::NORMAL) {
                    // DataCopy blockCount is capped; split tall tiles into chunks.
                    int32_t loop = height / MAX_BLOCK_COUNT_SIZE;
                    int32_t loopTail = height % MAX_BLOCK_COUNT_SIZE;
                    for (int i = 0; i < calcWidth; i++) {
                        int32_t dstOffsetTmp = dstOffset;
                        int32_t srcOffsetTmp = srcOffset;
                        // NOTE(review): inner loop shadows the outer index `i` — it works
                        // because the outer `i` is not read inside, but confirm intent.
                        for (int i = 0; i < loop; ++i) {
                            DataCopy(
                                dst[dstOffsetTmp], src[srcOffsetTmp],
                                { static_cast<uint16_t>(MAX_BLOCK_COUNT_SIZE), 1, static_cast<uint16_t>(src_gap), 0 });
                            dstOffsetTmp += MAX_BLOCK_COUNT_SIZE * c0Size_;
                            srcOffsetTmp += MAX_BLOCK_COUNT_SIZE * gCol;
                        }
                        if (loopTail) {
                            DataCopy(dst[dstOffsetTmp], src[srcOffsetTmp],
                                     { static_cast<uint16_t>(loopTail), 1, static_cast<uint16_t>(src_gap), 0 });
                        }
                        dstOffset += calcHeightExr * BLOCK_CUBE * c0Size_;
                        srcOffset += c0Size_;
                    }
                } else {
                    for (int i = 0; i < calcWidth; i++) {
                        DataCopy(dst[dstOffset], src[srcOffset],
                                 { static_cast<uint16_t>(height), 1, static_cast<uint16_t>(src_gap), 0 });
                        dstOffset += calcHeightExr * BLOCK_CUBE * c0Size_;
                        srcOffset += c0Size_;
                    }
                }
            }
            event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
            SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
            WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
            event_t eventIDMte1ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE1_MTE2));
            SetFlag<HardEvent::MTE1_MTE2>(eventIDMte1ToMte2);
            WaitFlag<HardEvent::MTE1_MTE2>(eventIDMte1ToMte2);
        }
    }

    // Copies an ND tile from GM to L1 as NZ by staging through two UB workspaces:
    // GM -> transTensor (ND, zero-padded) -> trans (NZ) -> dst (L1).
    //   row/col : tile origin (elements); gCol : source row pitch in elements.
    //   ndNum / kAlignToC0Size : accepted for interface parity with the >=220
    //   signature; unused on this path.
    // Fixes vs. original: removed a stray empty statement and the redundant
    // `? true : false` on the bank-conflict flag.
    __aicore__ inline void CopyND2NZ(const LocalTensor<SrcT>& dst, const GlobalTensor<SrcT>& src, int row, int col,
                                     int height, int width, int gCol, int ndNum = 1, bool kAlignToC0Size = false)
    {
        // Two workspace views: one for the staged ND data, one for the NZ result.
        LocalTensor<SrcT> transTensor;
        transTensor = MATMUL_MODULE(LocalWorkspace)->GetWorkspaceWithOffset(0).template ReinterpretCast<SrcT>();
        transTensor.SetSize(MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetTransLength());
        LocalTensor<SrcT> trans;
        trans = MATMUL_MODULE(LocalWorkspace)->GetWorkspaceWithOffset(
            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetTransLength()).template ReinterpretCast<SrcT>();
        trans.SetSize(MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetTransLength());

        // 64-bit arithmetic: row * gCol may overflow 32 bits for large matrices.
        auto srcOffset = ((int64_t)row * (int64_t)gCol + (int64_t)col);

        // Rows whose padded pitch is a 512B multiple but shorter than 32 blocks are
        // treated as bank-conflicting; downstream code pads one extra block per row.
        bool isBankConflict = (Ceil(width, c0Size_) * 32 % 512 == 0) && (Ceil(width, c0Size_) < 32);

        int calcHigh = Ceil(height, BLOCK_CUBE);
        auto enQueEvtID = GetTPipePtr()->FetchEventID(HardEvent::V_MTE2);
        SetFlag<HardEvent::V_MTE2>(enQueEvtID);
        WaitFlag<HardEvent::V_MTE2>(enQueEvtID);
        int calcWidth = CopyNDBlock(transTensor, src, srcOffset, height, width, gCol, isBankConflict);
        int padWidth = isBankConflict ? calcWidth + 1 : calcWidth;
        int size = calcHigh * padWidth * BLOCK_CUBE * c0Size_ / AuxGetFactor<SrcT>();

        transTensor.SetSize(size);
        trans.SetSize(size);
        // dst is logically an output; SetSize only adjusts the tensor view.
        const_cast<LocalTensor<SrcT>&>(dst).SetSize(size);

        NDPadZeros(transTensor, height, padWidth, gCol, width, isBankConflict);
        NDTrans2NZ(trans, transTensor, calcHigh, calcWidth, isBankConflict);

        event_t eventIDVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        SetFlag<HardEvent::V_MTE3>(eventIDVToMte3);
        WaitFlag<HardEvent::V_MTE3>(eventIDVToMte3);
        DataCopy(dst, trans, size);
        enQueEvtID = GetTPipePtr()->FetchEventID(HardEvent::MTE3_V);
        SetFlag<HardEvent::MTE3_V>(enQueEvtID);
        WaitFlag<HardEvent::MTE3_V>(enQueEvtID);
    }
#endif

private:
#if __CCE_AICORE__ < 220
    // Tile geometry cache for the pre-220 path.
    // NOTE(review): none of these members are referenced in this file's visible
    // code — confirm usage elsewhere before removing.
    int32_t orgHeight_;  // or M
    int32_t orgWidth_;   // or K
    int32_t baseHeight_; // or baseK
    int32_t baseWidth_;  // or baseM
    int32_t stepCol_;
#endif
    // Elements per 32B C0 block for the source element type.
    constexpr static int32_t c0Size_ = AuxGetC0Size<SrcT>();

    // Returns whether the matrix this wrapper feeds (A or B, chosen by the input
    // tag) is marked as transposed in the shape info.
    __aicore__ inline bool IsTranspose()
    {
        constexpr bool isInputA = (INPUT_TYPE::TAG == InputTypeTag::A);
        if constexpr (isInputA) {
            return MATMUL_MODULE(MatmulShapeInfo)->IsTransposeA();
        } else {
            return MATMUL_MODULE(MatmulShapeInfo)->IsTransposeB();
        }
    }

    // Copies an NZ-format tile from GM to L1.
    //   row/col : tile origin (elements) inside the source NZ matrix
    //   gRow    : source matrix height in elements (must be >= height)
    //   kAlignToC0Size : int8 only — widen the destination column pitch from the
    //                    16-aligned height up to the c0-aligned height.
    __aicore__ inline void CopyNZ2NZ(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src, const int32_t row,
                                     const int32_t col, const int32_t height, const int32_t width, const int32_t gRow,
                                     const bool kAlignToC0Size = false)
    {
        ASCENDC_ASSERT((gRow >= height), {
            KERNEL_LOG(
                KERNEL_ERROR,
                "NZ2NZ height larger than origin matrix height, gRow is %d, which should be no less than height %d.",
                gRow, height);
        });
        int32_t alignedGRow = Ceil(gRow, BLOCK_CUBE) * BLOCK_CUBE;
        // NZ layout: each C0 column occupies alignedGRow * c0Size_ elements.
        int64_t srcOffset = (int64_t)row * (int64_t)c0Size_ + (int64_t)col * (int64_t)alignedGRow;
        // height direction need to be 16 aligned
        auto alignHeight = Ceil(height, BLOCK_CUBE) * BLOCK_CUBE;
        int32_t blockLen = alignHeight * c0Size_ * sizeof(TransT) / ONE_BLK_SIZE; // burst length in 32B units
        int32_t srcStride = (alignedGRow - alignHeight) * (c0Size_ * sizeof(TransT) / ONE_BLK_SIZE);
        if constexpr (IsSameTypeV<TransT, int4b_t>) {
            // int4: two elements per byte, so halve the byte-based length/stride.
            blockLen /= INT4_TWO;
            srcStride /= INT4_TWO;
        }
        if (srcStride >= UINT16_MAX) {
            // Stride does not fit DataCopy's 16-bit field: copy column by column.
            for (int32_t i = 0; i < Ceil(width, c0Size_); ++i) {
                DataCopy(dst[i * alignHeight * c0Size_], src[srcOffset + i * gRow * c0Size_],
                         { 1, static_cast<uint16_t>(blockLen), 0, 0 });
            }
        } else {
            uint16_t nburst = Ceil(width, c0Size_);
            int32_t dstStride = 0;
            if constexpr (IsSameTypeV<TransT, int8_t>) {
                if (kAlignToC0Size) {
                    auto alignHeightC0Size = Ceil(height, c0Size_) * c0Size_;
                    dstStride = alignHeightC0Size - alignHeight;
                }
            }
            DataCopy(dst, src[srcOffset],
                     { nburst, static_cast<uint16_t>(blockLen), static_cast<uint16_t>(srcStride),
                       static_cast<uint16_t>(dstStride) });
        }
    };

    // NZ -> NZ copy between local buffers (UB/L1).
    //   row/col : tile origin (elements) inside the source NZ matrix
    //   gRow    : source matrix column pitch in elements (must be >= height)
    // Consistency fix: align height with Ceil(height, BLOCK_CUBE) * BLOCK_CUBE as
    // the GM-source overload does, instead of the hand-rolled (h + 15) / 16 * 16.
    __aicore__ inline void CopyNZ2NZ(const LocalTensor<TransT>& dst, const LocalTensor<SrcT>& src, const int32_t row,
                                     const int32_t col, const int32_t height, const int32_t width, const int32_t gRow)
    {
        ASCENDC_ASSERT((gRow >= height), {
            KERNEL_LOG(KERNEL_ERROR, "gRow is %d, which should be no less than height %d.", gRow, height);
        });
        int32_t srcOffset = row * c0Size_ + col * gRow;
        // height direction need to be 16 aligned
        auto alignHeight = Ceil(height, BLOCK_CUBE) * BLOCK_CUBE;
        int32_t blockLen = alignHeight * c0Size_ * sizeof(TransT) / ONE_BLK_SIZE; // burst length in 32B units
        int32_t srcStride = (gRow - alignHeight) * (c0Size_ * sizeof(TransT) / ONE_BLK_SIZE);

        if (srcStride >= UINT16_MAX) {
            // Stride does not fit DataCopy's 16-bit field: copy column by column.
            for (int32_t i = 0; i < width / c0Size_; ++i) {
                DataCopy(dst[i * alignHeight * c0Size_], src[srcOffset + i * gRow * c0Size_],
                         { 1, static_cast<uint16_t>(blockLen), 0, 0 });
            }
        } else {
            DataCopy(dst, src[srcOffset],
                     { static_cast<uint16_t>(width / c0Size_), static_cast<uint16_t>(blockLen),
                       static_cast<uint16_t>(srcStride), 0 });
        }
    }

#if __CCE_AICORE__ < 220
    // Stages one ND tile from GM into UB (transTensor), row-major, optionally
    // leaving one extra 32B block of padding per row when isBankConflict is set.
    // Returns the row width in c0Size_ units (C0 blocks per row, tail included).
    //   srcOffset : element offset of the tile origin inside src
    //   gCol      : source row pitch in elements (must be >= width)
    // Bug fix: when height was an exact multiple of MAX_BLOCK_COUNT_SIZE, the last
    // loop iteration computed blockCount == 0 and issued a zero-burst DataCopy,
    // silently dropping the final chunk. The last chunk is now a full burst when
    // there is no remainder.
    template <typename T>
    __aicore__ inline int CopyNDBlock(const LocalTensor<T>& transTensor, const GlobalTensor<T>& src, int64_t srcOffset,
                                      int height, int width, int gCol, bool isBankConflict)
    {
        ASCENDC_ASSERT((gCol >= width),
                       { KERNEL_LOG(KERNEL_ERROR, "gCol is %d, which should be no less than %d.", gCol, width); });
        int calcWidth = width / c0Size_; // cube block numbers that do not need to be pad zero
        // Elements per 32B block for this element width.
        int c0Size = B16_C0SIZE;
        if constexpr (sizeof(T) == sizeof(float)) {
            c0Size = B32_C0SIZE;
        } else if constexpr (sizeof(T) == sizeof(int8_t)) { // keep the whole chain constexpr
            c0Size = B8_C0SIZE;
        }

        // gCol unaligned
        if (gCol % c0Size) {
            calcWidth = Ceil<int32_t>(CeilAlign<int32_t>(width, c0Size), c0Size_);
            int blockLen = CeilAlign<int32_t>(width, c0Size) * sizeof(T) / DEFAULT_C0_SIZE;
            int dstOffset = 0;
            int BankConflictPadSize = isBankConflict ? (32 / sizeof(T)) : 0;

            // data copy stride is unaligned, need to copy line by line
            for (int i = 0; i < height; i++) {
                DataCopy(transTensor[dstOffset], src[srcOffset], { 1, static_cast<uint16_t>(blockLen), 0, 0 });
                dstOffset += (CeilAlign<int32_t>(width, c0Size) + BankConflictPadSize);
                srcOffset += gCol;
            }

            auto enQueEvtID = GetTPipePtr()->FetchEventID(HardEvent::MTE2_V);
            SetFlag<HardEvent::MTE2_V>((event_t)enQueEvtID);
            WaitFlag<HardEvent::MTE2_V>((event_t)enQueEvtID);
        } else {
            int srcStride = (gCol - width) * sizeof(T) / ONE_BLK_SIZE;
            int blocklen = Ceil<int32_t>(width * sizeof(T), ONE_BLK_SIZE);
            calcWidth = Ceil<int32_t>(CeilAlign<int32_t>(width, c0Size), c0Size_);
            if (srcStride >= UINT16_MAX) {
                // Source stride exceeds the 16-bit field: copy row by row.
                int dstOffset = isBankConflict ? (width + c0Size) : width;
                for (int i = 0; i < height; ++i) {
                    DataCopy(transTensor[i * dstOffset], src[srcOffset], { 1, static_cast<uint16_t>(blocklen), 0, 0 });
                    srcOffset += gCol;
                }
            } else {
                uint16_t dstStride = isBankConflict ? 1 : 0;
                int loopNum = Ceil<int32_t>(static_cast<uint16_t>(height), MAX_BLOCK_COUNT_SIZE);
                int tailCount = static_cast<uint16_t>(height) % MAX_BLOCK_COUNT_SIZE;
                for (int i = 0; i < loopNum; ++i) {
                    // Last iteration copies the remainder only when one exists;
                    // otherwise every chunk (including the last) is a full burst.
                    uint16_t blockCount =
                        (i == loopNum - 1 && tailCount != 0) ? tailCount : MAX_BLOCK_COUNT_SIZE;
                    DataCopy(
                        transTensor[i * MAX_BLOCK_COUNT_SIZE * blocklen * ONE_BLK_SIZE / sizeof(T)],
                        src[srcOffset + i * MAX_BLOCK_COUNT_SIZE * blocklen * ONE_BLK_SIZE / sizeof(T)],
                        { blockCount, static_cast<uint16_t>(blocklen), static_cast<uint16_t>(srcStride), dstStride });
                }
            }
            auto enQueEvtID = GetTPipePtr()->FetchEventID(HardEvent::MTE2_V);
            SetFlag<HardEvent::MTE2_V>((event_t)enQueEvtID);
            WaitFlag<HardEvent::MTE2_V>((event_t)enQueEvtID);
        }
        return calcWidth;
    }

    // Zero-pads the staged ND tile in UB so the NZ transform sees whole fractals:
    //  - when the source pitch is unaligned, clears the lanes past the ragged
    //    column tail of every row via masked Duplicate;
    //  - when height is not a multiple of 16, clears the rows below the tile.
    // calcWidth is the padded row width in C0 blocks (bank-conflict pad included).
    template <class T>
    __aicore__ inline void NDPadZeros(LocalTensor<T>& dst, int height, int calcWidth, int gCol, int width,
                                      bool isBankConflict)
    {
        if (gCol % BLOCK_CUBE) {
            int tail = width % c0Size_;
            // tail pad zero
            if (tail) {
                auto offset = width / c0Size_ * c0Size_; // first lane of the ragged block
                uint64_t mask[2];
                if constexpr (IsSameType<SrcT, int8_t>::value) {
                    // int8 is cleared through an int16 view: halve tail and offset.
                    tail = Ceil(tail, 2);
                    offset /= 2;
                }
                // Mask selects only the lanes at/after `tail` in each 16-lane group.
                uint16_t mask_tail = ~((1 << tail) - 1);
                uint64_t masktail = mask_tail;
                mask[0] = masktail + (masktail << 16) + (masktail << 32) + (masktail << 48);
                mask[1] = mask[0];
                int stride = calcWidth * (c0Size_ * sizeof(T) / DEFAULT_C0_SIZE); // row pitch in 32B blocks
                int32_t totalRep = Ceil(height, 8); // 8 rows handled per Duplicate repeat
                if (masktail != 0) {
                    if constexpr (IsSameType<SrcT, int8_t>::value) {
                        LocalTensor<int16_t> tmpTransTensor = dst.template ReinterpretCast<int16_t>();
                        if (stride < 32) {
                            // Repeat stride fits the instruction field.
                            if (totalRep <= MAX_REPEAT_TIMES) {
                                Duplicate(tmpTransTensor[offset], (int16_t)0, mask, Ceil(height, 8), stride,
                                          8 * stride);
                            } else {
                                // Split into MAX_REPEAT_TIMES batches plus a remainder.
                                int32_t highBlock = totalRep / MAX_REPEAT_TIMES;
                                int32_t highTail = totalRep % MAX_REPEAT_TIMES;
                                int64_t dstOffset = calcWidth * BLOCK_CUBE * 8 * MAX_REPEAT_TIMES;
                                for (int32_t idx = 0; idx < highBlock; ++idx) {
                                    Duplicate(tmpTransTensor[offset], (int16_t)0, mask, MAX_REPEAT_TIMES, stride,
                                              8 * stride);
                                    offset += dstOffset;
                                }
                                if (highTail) {
                                    Duplicate(tmpTransTensor[offset], (int16_t)0, mask, highTail, stride, 8 * stride);
                                }
                            }
                        } else {
                            // Stride too large for the repeat field: clear group by group.
                            for (int32_t i = 0; i < totalRep; ++i) {
                                Duplicate(tmpTransTensor[offset], (int16_t)0, mask, 1, stride, 0);
                                offset += stride * BLOCK_CUBE;
                            }
                        }
                    } else {
                        Duplicate(dst[offset], (T)0, mask, totalRep, stride, 8 * stride);
                    }
                    PipeBarrier<PIPE_V>();
                }
            }
        }
        // If the value of high is not an integer multiple of 16, add 0.
        int tailHigh = height % BLOCK_CUBE;
        if (tailHigh) {
            auto dstOffset = height * calcWidth * BLOCK_CUBE;
            if constexpr (IsSameType<SrcT, int8_t>::value) {
                LocalTensor<int16_t> tmpDst = dst.template ReinterpretCast<int16_t>();
                Duplicate(tmpDst[dstOffset], (int16_t)0, (BLOCK_CUBE - tailHigh) * calcWidth * BLOCK_CUBE);
            } else {
                Duplicate(dst[dstOffset], (T)0, (BLOCK_CUBE - tailHigh) * calcWidth * BLOCK_CUBE);
            }
        }
    }

    /*!
     * \brief Rearranges an ND-layout tile into NZ layout in local memory, using
     *        Muls (multiply by 1) as a strided copy.
     * \param dst destination local tensor receiving NZ-format data
     * \param src source local tensor holding ND-format data
     * \param calcHigh tile height counted in BLOCK_CUBE-row blocks
     * \param calcWidth tile width counted in BLOCK_CUBE-column blocks
     * \param isBankConflict true when one extra block of per-row padding was added
     *        to avoid bank conflicts; widens the effective source row stride
     */
    __aicore__ inline void NDTrans2NZ(LocalTensor<SrcT>& dst, LocalTensor<SrcT>& src, int calcHigh, int calcWidth,
                                      bool isBankConflict)
    {
        // Use Muls, convert to NZ format
        if constexpr (IsSameType<SrcT, int8_t>::value) {
            // int8 path: reinterpret as int16 so every vector lane moves two int8 elements.
            UnaryRepeatParams intriParams;
            uint64_t mask[2] = { uint64_t(-1), uint64_t(-1) };
            int blkStride = isBankConflict ? calcWidth + 1 : calcWidth;
            intriParams.dstBlkStride = (c0Size_ * sizeof(SrcT) / DEFAULT_C0_SIZE);
            intriParams.srcBlkStride = blkStride * (c0Size_ * sizeof(SrcT) / DEFAULT_C0_SIZE);
            intriParams.dstRepStride = intriParams.dstBlkStride * DEFAULT_BLK_NUM;
            intriParams.srcRepStride = intriParams.srcBlkStride * DEFAULT_BLK_NUM;
            int dstOffset = 0;
            int srcOffset = 0;
            // ensure rep stride be less than 256
            constexpr int maxSrcBlkStride = 32;
            LocalTensor<int16_t> tmpSrc = src.template ReinterpretCast<int16_t>();
            LocalTensor<int16_t> tmpDst = dst.template ReinterpretCast<int16_t>();
            if (intriParams.srcBlkStride >= maxSrcBlkStride) {
                // Source stride would exceed the hardware limit: fall back to one Muls
                // per BLOCK_CUBE row with unit strides and a 16-lane mask.
                intriParams.dstBlkStride = 1;
                intriParams.srcBlkStride = 1;
                mask[0] = (1 << BLOCK_CUBE) - 1;
                mask[1] = 0;
                SetVectorMask<int16_t>(mask[1], mask[0]);
                for (int i = 0; i < calcWidth; i++) {
                    for (int j = 0; j < calcHigh * BLOCK_CUBE; ++j) {
                        dstOffset = i * calcHigh * CUBE_MAX_SIZE + j * BLOCK_CUBE;
                        srcOffset = j * blkStride * BLOCK_CUBE + i * BLOCK_CUBE;
                        Muls<int16_t, false>(tmpDst[dstOffset], tmpSrc[srcOffset], (int16_t)1, mask, 1, intriParams);
                    }
                }
            } else {
                SetVectorMask<int16_t>(mask[1], mask[0]);
                // Split the repeat count so a single Muls never exceeds MAX_REPEAT_TIMES.
                int32_t totalRepTimes = 2 * calcHigh;
                int32_t highBlock = totalRepTimes / MAX_REPEAT_TIMES;
                int32_t highTail = totalRepTimes % MAX_REPEAT_TIMES;
                for (int i = 0; i < calcWidth; i++) {
                    dstOffset = i * calcHigh * CUBE_MAX_SIZE;
                    srcOffset = i * BLOCK_CUBE;
                    for (int32_t idx = 0; idx < highBlock; ++idx) {
                        Muls<int16_t, false>(tmpDst[dstOffset], tmpSrc[srcOffset], (int16_t)1, mask, MAX_REPEAT_TIMES,
                                             intriParams);
                        dstOffset += BLOCK_CUBE * MAX_REPEAT_TIMES * 8;
                        srcOffset += calcWidth * BLOCK_CUBE * MAX_REPEAT_TIMES * 8;
                    }
                    if (highTail) {
                        Muls<int16_t, false>(tmpDst[dstOffset], tmpSrc[srcOffset], (int16_t)1, mask, highTail,
                                             intriParams);
                    }
                }
            }
        } else {
            const int c0Count = AscendCUtils::GetC0Count(sizeof(SrcT));
            UnaryRepeatParams intriParams;
            uint64_t mask[2] = { uint64_t(-1), uint64_t(-1) };
            // NOTE: this branch is only reached when SrcT != int8_t (the int8 case is
            // handled above), so the former padBlock = 2 computation guarded by
            // (TransT == half && SrcT == int8_t) could never fire; bank-conflict
            // padding is therefore always exactly one block here.
            int blkStride = isBankConflict ? calcWidth + 1 : calcWidth;
            intriParams.dstBlkStride = (BLOCK_CUBE * sizeof(SrcT) / DEFAULT_C0_SIZE);
            intriParams.srcBlkStride = blkStride * BLOCK_CUBE * sizeof(SrcT) / DEFAULT_C0_SIZE;
            intriParams.dstRepStride = intriParams.dstBlkStride * DEFAULT_BLK_NUM;
            intriParams.srcRepStride = intriParams.srcBlkStride * DEFAULT_BLK_NUM;
            int dstOffset = 0;
            int srcOffset = 0;
            // ensure rep stride be less than 256
            constexpr int maxSrcBlkStride = 32;
            if (intriParams.srcBlkStride >= maxSrcBlkStride) {
                // Same hardware-stride fallback as the int8 path: row-by-row copy.
                intriParams.dstBlkStride = 1;
                intriParams.srcBlkStride = 1;
                mask[0] = (1 << BLOCK_CUBE) - 1;
                mask[1] = 0;
                SetVectorMask<SrcT>(mask[1], mask[0]);
                for (int i = 0; i < calcWidth; i++) {
                    for (int j = 0; j < calcHigh * BLOCK_CUBE; ++j) {
                        dstOffset = i * calcHigh * CUBE_MAX_SIZE + j * BLOCK_CUBE;
                        srcOffset = j * blkStride * BLOCK_CUBE + i * BLOCK_CUBE;
                        Muls<SrcT, false>(dst[dstOffset], src[srcOffset], (SrcT)1, mask, 1, intriParams);
                        if constexpr (sizeof(SrcT) == sizeof(float)) {
                            // 4-byte types span two C0 blocks per BLOCK_CUBE columns;
                            // copy the second C0 block as well.
                            Muls<SrcT, false>(dst[dstOffset + c0Count], src[srcOffset + c0Count], (SrcT)1, mask, 1,
                                              intriParams);
                        }
                    }
                }
            } else {
                SetVectorMask<SrcT>(mask[1], mask[0]);
                for (int i = 0; i < calcWidth; i++) {
                    dstOffset = i * calcHigh * CUBE_MAX_SIZE;
                    srcOffset = i * BLOCK_CUBE;
                    Muls<SrcT, false>(dst[dstOffset], src[srcOffset], (SrcT)1, mask, 2 * calcHigh, intriParams);
                    if constexpr (sizeof(SrcT) == sizeof(float)) {
                        // Second C0 block for 4-byte types (see fallback path above).
                        Muls<SrcT, false>(dst[dstOffset + c0Count], src[srcOffset + c0Count], (SrcT)1, mask,
                                          2 * calcHigh, intriParams);
                    }
                }
            }
        }
    }

#endif

#if __CCE_AICORE__ >= 220
    /*!
     * \brief Zero-fills the padded regions of an NZ tile when the actual copied tile
     *        is smaller than the statically configured tile shape.
     * \param dst destination local tensor laid out in NZ format
     * \param staticHeight statically configured tile height in elements
     * \param staticWidth statically configured tile width in elements
     * \param tileHeight actual tile height in elements
     * \param tileWidth actual tile width in elements
     */
    template <typename DataType>
    __aicore__ inline void StaticPadNd2Nz(const LocalTensor<DataType>& dst, const int32_t staticHeight,
        const int32_t staticWidth, const int32_t tileHeight, const int32_t tileWidth)
    {
        if constexpr (DoMatmulNorm(MM_CFG) || DoMatmulBasicBlock(MM_CFG) || DoMatmulSpecialBasicBlock(MM_CFG)) {
            const int32_t tileC0Num = Ceil(tileWidth, c0Size_);
            const int32_t staticC0Num = Ceil(staticWidth, c0Size_);
            if (tileHeight < staticHeight) {
                // Zero the rows below the valid data in each occupied C0 column
                // (the left-bottom area of the tile).
                InitConstValueParams<DataType> padParams;
                padParams.repeatTimes = tileC0Num;
                padParams.blockNum = staticHeight - tileHeight;
                padParams.dstGap = tileHeight;
                padParams.initValue = 0;
                InitConstValue(dst[tileHeight * c0Size_], padParams);
            }
            if (tileC0Num < staticC0Num) {
                // Zero all trailing C0 columns beyond the valid width in one shot
                // (the right area of the tile).
                InitConstValueParams<DataType> padParams;
                padParams.repeatTimes = 1;
                padParams.blockNum = (staticC0Num - tileC0Num) * staticHeight;
                padParams.dstGap = 0;
                padParams.initValue = 0;
                InitConstValue(dst[staticHeight * tileC0Num * c0Size_], padParams);
            }
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            // MDL variants clear the whole static tile up front instead of padding
            // the two border regions separately.
            using Params = InitConstValueParams<DataType>;
            InitConstValue(dst,
                Params{ 1, static_cast<uint16_t>(staticHeight * staticWidth * sizeof(DataType) / ONE_BLK_SIZE), 0, 0 });
        }
    }

    /*!
     * \brief Copies an ND tile from global memory into a local buffer in NZ format.
     * \param dst destination local tensor (NZ layout)
     * \param src source global tensor (ND layout)
     * \param row starting row of the tile in the source matrix, must be >= 0
     * \param col starting column of the tile in the source matrix, must be >= 0
     * \param height tile height in elements, must be > 0
     * \param width tile width in elements, must be > 0
     * \param gCol leading dimension (full row width) of the source matrix, >= width
     * \param ndNum number of ND matrices copied in one call
     * \param srcNdMatrixStride stride between consecutive source ND matrices
     * \param dstNzMatrixStride stride between consecutive destination NZ matrices
     * \param kAlignToC0Size align dstNzC0Stride to c0Size_ instead of BLOCK_CUBE
     *        (needed when k is the row axis for int8-type input)
     */
    __aicore__ inline void CopyND2NZ(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src, const int32_t row,
                                     const int32_t col, const int32_t height, const int32_t width, const int32_t gCol,
                                     const int32_t ndNum = 1, const int32_t srcNdMatrixStride = 0,
                                     const int32_t dstNzMatrixStride = 0, const bool kAlignToC0Size = false)
    {
        ASCENDC_ASSERT((row >= 0), { KERNEL_LOG(KERNEL_ERROR, "row is %d, which should be no less than 0.", row); });
        ASCENDC_ASSERT((col >= 0), { KERNEL_LOG(KERNEL_ERROR, "col is %d, which should be no less than 0.", col); });
        // Log text fixed: the checked condition is "> 0", not ">= 0".
        ASCENDC_ASSERT((height > 0),
                       { KERNEL_LOG(KERNEL_ERROR, "height is %d, which should be larger than 0.", height); });
        ASCENDC_ASSERT((width > 0),
                       { KERNEL_LOG(KERNEL_ERROR, "width is %d, which should be larger than 0.", width); });
        ASCENDC_ASSERT((gCol >= width), {
            KERNEL_LOG(
                KERNEL_ERROR,
                "ND2NZ width larger than origin matrix width, gCol is %d, which should be no less than width %d.", gCol,
                width);
        });
        int32_t dstNzC0Stride = 0;
        if constexpr (IsStaticPaddingEnable(MM_CFG)) {
            // With static padding, zero-fill the gap between the actual tile and the
            // statically configured tile shape before issuing the DMA.
            int32_t tileHeight = GetStaticTileHeight<INPUT_TYPE::isTrans>();
            int32_t tileWidth = GetStaticTileWidth<INPUT_TYPE::isTrans>();
            if (tileHeight != height || tileWidth != width) {
                StaticPadNd2Nz<TransT>(dst, tileHeight, tileWidth, height, width);
                dstNzC0Stride = tileHeight;
            }
        }
        int64_t srcOffset;
        if constexpr (IsSameTypeV<TransT, int4b_t>) {
            // int4 packs two elements per byte, so the row stride doubles in element units.
            srcOffset = ((int64_t)row * (int64_t)gCol * INT4_TWO + (int64_t)col);
        } else {
            srcOffset = ((int64_t)row * (int64_t)gCol + (int64_t)col);
        }
        Nd2NzParams nd2nzParams;
        nd2nzParams.ndNum = ndNum;
        nd2nzParams.nValue = height;
        nd2nzParams.dValue = width;
        nd2nzParams.srcNdMatrixStride = srcNdMatrixStride;
        nd2nzParams.srcDValue = gCol;

        if (dstNzC0Stride) {
            nd2nzParams.dstNzC0Stride = dstNzC0Stride;
        } else {
            // when k is row(height) axis, int8 type gm->l1 nd2nz should be aligned to 32(c0Size)
            // while float/half type should be aligned to 16
            if (kAlignToC0Size) {
                nd2nzParams.dstNzC0Stride = Ceil(height, c0Size_) * c0Size_;
            } else {
                nd2nzParams.dstNzC0Stride = Ceil(height, BLOCK_CUBE) * BLOCK_CUBE;
            }
        }
        nd2nzParams.dstNzNStride = 1;
        nd2nzParams.dstNzMatrixStride = dstNzMatrixStride;
#if __CCE_AICORE__ == 220
        if constexpr (!ToMatmulConfig(MM_CFG).intrinsicsCheck) {
            DataCopy(dst, src[srcOffset], nd2nzParams);
        } else {
            if (gCol >= UINT16_MAX) {
                // gCol does not fit in srcDValue: fall back to a row-by-row copy.
                nd2nzParams.nValue = 1;
                nd2nzParams.srcDValue = width;
                for (int32_t i = 0; i < height; ++i) {
                    // Widen before multiplying: gCol * i in int32 can overflow here,
                    // since this branch only runs with gCol >= UINT16_MAX.
                    DataCopy(dst[i * c0Size_], src[srcOffset + (int64_t)gCol * i], nd2nzParams);
                }
            } else {
                DataCopy(dst, src[srcOffset], nd2nzParams);
            }
        }
#else
        DataCopy(dst, src[srcOffset], nd2nzParams); // stride scope has increased
#endif
    }

#endif
};
}  // namespace Detail
}  // namespace Impl
}  // namespace AscendC
#endif // IMPL_MATMUL_MODULES_STAGE_BATCH_COPY_CUBE_IN_BATCH_DATA_COPY_WRAPPER_H
